Add adaptive pooling (1D, 2D, 3D) support across JAX, TensorFlow, and PyTorch backends #21820
Conversation
Summary of Changes: Hello @MalyalaKarthik66, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances Keras 3 by adding adaptive average and max pooling layers for 2D spatial data. These new layers allow users to specify a target output size, with the pooling kernel and stride adjusted automatically, providing greater flexibility in network architectures, particularly for tasks that require fixed-size feature maps regardless of input dimensions. The implementation prioritizes the JAX backend while ensuring seamless integration with the other Keras backends.
Code Review
This pull request introduces AdaptiveAveragePooling2D and AdaptiveMaxPooling2D layers, along with their corresponding backend operations. The changes include the layer definitions, JAX backend implementations, ops API, and comprehensive tests. The layer APIs and tests are well-designed. However, the JAX backend implementation has significant performance issues due to the use of Python loops, which are not JIT-compatible. There are also opportunities to improve code quality by removing dead code and reducing duplication. My review provides specific feedback on these points.
keras/src/backend/jax/nn.py
Outdated
for i in range(out_h):
    for j in range(out_w):
        # Calculate pooling region for this output position
        start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32)
        end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32)
        start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32)
        end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32)

        # Extract region and apply average pooling
        if data_format == "channels_last":
            region = inputs[:, start_h:end_h, start_w:end_w, :]
            # Average over spatial dimensions (axis 1, 2)
            pooled = jnp.mean(region, axis=(1, 2))
        else:  # channels_first
            region = inputs[:, :, start_h:end_h, start_w:end_w]
            # Average over spatial dimensions (axis 2, 3)
            pooled = jnp.mean(region, axis=(2, 3))

        result_list.append(pooled)
The current implementation of adaptive pooling uses Python for loops to iterate over output positions. This is an anti-pattern in JAX as it prevents JIT compilation and leads to very poor performance, especially for larger inputs or output sizes. The computation should be expressed using JAX's vectorized operations or JIT-compatible loops like lax.fori_loop to achieve good performance. A fully vectorized einsum-based approach for average pooling, or a lax.fori_loop over output pixels for both pooling types, would be significantly more performant. This comment also applies to the adaptive_max_pool implementation.
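As a minimal sketch of the einsum suggestion for average pooling, assuming `channels_last` inputs and static shapes (the helper name `_pooling_matrix` is illustrative, not from the PR): the pooling weights are built with NumPy at trace time, so the whole function stays JIT-compatible. Note this trick only applies to average pooling, since max is not a linear reduction.

import numpy as np
import jax.numpy as jnp


def _pooling_matrix(in_size, out_size):
    # Row i holds uniform weights 1 / (end - start) over the PyTorch-style
    # adaptive region [floor(i * L / O), ceil((i + 1) * L / O)).
    weights = np.zeros((out_size, in_size), dtype=np.float32)
    for i in range(out_size):
        start = i * in_size // out_size
        end = -(-(i + 1) * in_size // out_size)  # ceil division
        weights[i, start:end] = 1.0 / (end - start)
    return jnp.asarray(weights)


def adaptive_avg_pool2d_einsum(inputs, out_h, out_w):
    w_h = _pooling_matrix(inputs.shape[1], out_h)  # (out_h, in_h)
    w_w = _pooling_matrix(inputs.shape[2], out_w)  # (out_w, in_w)
    # Contract height and width against the two pooling matrices in one shot.
    return jnp.einsum("bhwc,ph,qw->bpqc", inputs, w_h, w_w)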
keras/src/backend/jax/nn.py
Outdated
def _adaptive_pool_start_index(output_idx, output_size, input_size):
    """Calculate start index for adaptive pooling (PyTorch compatible)."""
    return jnp.floor((output_idx * input_size) / output_size).astype(jnp.int32)


def _adaptive_pool_end_index(output_idx, output_size, input_size):
    """Calculate end index for adaptive pooling (PyTorch compatible)."""
    return jnp.ceil(((output_idx + 1) * input_size) / output_size).astype(
        jnp.int32
    )
keras/src/backend/jax/nn.py
Outdated
def adaptive_avg_pool(
    inputs, output_size, data_format="channels_last", name=None
):
    """
    Adaptive average pooling for JAX backend (PyTorch-compatible).
    """
    # Convert output_size to tuple
    spatial_dims = inputs.ndim - 2
    if isinstance(output_size, int):
        output_size = (output_size,) * spatial_dims
    else:
        output_size = tuple(output_size)

    # Get spatial shape
    if data_format == "channels_last":
        batch_size = inputs.shape[0]
        channels = inputs.shape[-1]
        spatial_shape = inputs.shape[1:-1]
    else:  # channels_first
        batch_size = inputs.shape[0]
        channels = inputs.shape[1]
        spatial_shape = inputs.shape[2:]

    if len(output_size) != 2:
        raise NotImplementedError(
            "Only 2D adaptive pooling is currently supported"
        )

    out_h, out_w = output_size
    in_h, in_w = spatial_shape

    # Build output by iterating over output positions
    result_list = []

    for i in range(out_h):
        for j in range(out_w):
            # Calculate pooling region for this output position
            start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32)
            end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32)
            start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32)
            end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32)

            # Extract region and apply average pooling
            if data_format == "channels_last":
                region = inputs[:, start_h:end_h, start_w:end_w, :]
                # Average over spatial dimensions (axis 1, 2)
                pooled = jnp.mean(region, axis=(1, 2))
            else:  # channels_first
                region = inputs[:, :, start_h:end_h, start_w:end_w]
                # Average over spatial dimensions (axis 2, 3)
                pooled = jnp.mean(region, axis=(2, 3))

            result_list.append(pooled)

    # Stack results: (out_h*out_w, batch, channels)
    output = jnp.stack(result_list, axis=0)

    # Reshape and transpose to correct output shape
    if data_format == "channels_last":
        # (out_h*out_w, batch, channels) -> (batch, out_h, out_w, channels)
        output = output.reshape(out_h, out_w, batch_size, channels)
        output = jnp.transpose(output, (2, 0, 1, 3))
    else:  # channels_first
        # (out_h*out_w, batch, channels) -> (batch, channels, out_h, out_w)
        output = output.reshape(out_h, out_w, batch_size, channels)
        output = jnp.transpose(output, (2, 3, 0, 1))

    return output


def adaptive_max_pool(
    inputs, output_size, data_format="channels_last", name=None
):
    """
    Adaptive max pooling for JAX backend (PyTorch-compatible).
    """
    # Convert output_size to tuple
    spatial_dims = inputs.ndim - 2
    if isinstance(output_size, int):
        output_size = (output_size,) * spatial_dims
    else:
        output_size = tuple(output_size)

    # Get spatial shape
    if data_format == "channels_last":
        batch_size = inputs.shape[0]
        channels = inputs.shape[-1]
        spatial_shape = inputs.shape[1:-1]
    else:  # channels_first
        batch_size = inputs.shape[0]
        channels = inputs.shape[1]
        spatial_shape = inputs.shape[2:]

    if len(output_size) != 2:
        raise NotImplementedError(
            "Only 2D adaptive pooling is currently supported"
        )

    out_h, out_w = output_size
    in_h, in_w = spatial_shape

    # Build output by iterating over output positions
    result_list = []

    for i in range(out_h):
        for j in range(out_w):
            # Calculate pooling region for this output position
            start_h = jnp.floor((i * in_h) / out_h).astype(jnp.int32)
            end_h = jnp.ceil(((i + 1) * in_h) / out_h).astype(jnp.int32)
            start_w = jnp.floor((j * in_w) / out_w).astype(jnp.int32)
            end_w = jnp.ceil(((j + 1) * in_w) / out_w).astype(jnp.int32)

            # Extract region and apply max pooling
            if data_format == "channels_last":
                region = inputs[:, start_h:end_h, start_w:end_w, :]
                # Max over spatial dimensions (axis 1, 2)
                pooled = jnp.max(region, axis=(1, 2))
            else:  # channels_first
                region = inputs[:, :, start_h:end_h, start_w:end_w]
                # Max over spatial dimensions (axis 2, 3)
                pooled = jnp.max(region, axis=(2, 3))

            result_list.append(pooled)

    # Stack results: (out_h*out_w, batch, channels)
    output = jnp.stack(result_list, axis=0)

    # Reshape and transpose to correct output shape
    if data_format == "channels_last":
        # (out_h*out_w, batch, channels) -> (batch, out_h, out_w, channels)
        output = output.reshape(out_h, out_w, batch_size, channels)
        output = jnp.transpose(output, (2, 0, 1, 3))
    else:  # channels_first
        # (out_h*out_w, batch, channels) -> (batch, channels, out_h, out_w)
        output = output.reshape(out_h, out_w, batch_size, channels)
        output = jnp.transpose(output, (2, 3, 0, 1))

    return output
The functions adaptive_avg_pool and adaptive_max_pool are nearly identical, with the only difference being the pooling operation (jnp.mean vs jnp.max). This code duplication can be avoided by creating a generic _adaptive_pool helper function that takes the pooling function as an argument. This would improve maintainability and reduce redundancy.
For example:
def _adaptive_pool(inputs, output_size, data_format, pool_op):
    # ... common setup code ...
    for i in range(out_h):
        for j in range(out_w):
            # ... common region calculation ...
            if data_format == "channels_last":
                region = inputs[:, start_h:end_h, start_w:end_w, :]
                pooled = pool_op(region, axis=(1, 2))
            else:  # channels_first
                region = inputs[:, :, start_h:end_h, start_w:end_w]
                pooled = pool_op(region, axis=(2, 3))
            result_list.append(pooled)
    # ... common reshape and transpose code ...
    return output


def adaptive_avg_pool(inputs, output_size, data_format="channels_last", name=None):
    # ...
    return _adaptive_pool(inputs, output_size, data_format, jnp.mean)


def adaptive_max_pool(inputs, output_size, data_format="channels_last", name=None):
    # ...
    return _adaptive_pool(inputs, output_size, data_format, jnp.max)

Note that this refactoring suggestion still contains the performance issue mentioned in another comment. The primary goal here is to illustrate how to reduce code duplication.
Codecov Report

❌ Patch coverage is … Additional details and impacted files:

@@            Coverage Diff             @@
##           master   #21820      +/-   ##
==========================================
- Coverage   82.66%   82.42%   -0.25%
==========================================
  Files         577      584       +7
  Lines       59419    60113     +694
  Branches     9313     9429     +116
==========================================
+ Hits        49121    49547     +426
- Misses       7898     8087     +189
- Partials     2400     2479      +79

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
…X, NumPy, PyTorch, and TensorFlow backends
Code Review
This pull request adds adaptive pooling support for JAX, TensorFlow, and PyTorch backends, which is a great addition. The implementation for JAX and TensorFlow uses a custom "Two-Pool Gather" algorithm, while the PyTorch implementation leverages native operations. The code is well-structured and includes corresponding unit tests.
My review focuses on improving maintainability by reducing code duplication in the backend implementations, ensuring user-facing elements like docstrings and error messages are clear and accurate, and maintaining code style consistency. I've provided several suggestions to address these points.
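For context, the two helpers this algorithm leans on, `get_static_window_sizes` and `compute_static_gather_indices`, are not defined in the hunks below. The following is a plausible reconstruction of what they compute, offered only as a reading aid and not as the PR's actual code: PyTorch-style adaptive windows take at most two distinct sizes, so one stride-1 `reduce_window` per size plus a static gather reproduces adaptive pooling without Python loops over output positions.

import jax.numpy as jnp


def get_static_window_sizes(input_size, output_size):
    # Window i covers [floor(i * L / O), ceil((i + 1) * L / O)); across all
    # i the window lengths take at most two distinct values.
    sizes = [
        -(-(i + 1) * input_size // output_size)  # ceil((i + 1) * L / O)
        - (i * input_size // output_size)        # floor(i * L / O)
        for i in range(output_size)
    ]
    return min(sizes), max(sizes)


def compute_static_gather_indices(input_size, output_size, big_size):
    # A stride-1 "valid" reduce_window with window w yields one output per
    # start position, so window i sits at index start_i in the small pool,
    # or at num_small_positions + start_i once the big pool is concatenated
    # after it.
    small_size, _ = get_static_window_sizes(input_size, output_size)
    num_small_positions = input_size - small_size + 1
    indices = []
    for i in range(output_size):
        start = i * input_size // output_size
        end = -(-(i + 1) * input_size // output_size)
        if end - start == small_size:
            indices.append(start)
        else:
            indices.append(num_small_positions + start)
    return jnp.asarray(indices, dtype=jnp.int32)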
n, l, c = inputs.shape
out_l = output_size[0]
The variable names n, l, c, and out_l are quite short. According to the Keras API design guidelines, it's preferred to use fully spelled-out names to improve readability, with a few common exceptions like dim and num.1 Consider using more descriptive names like batch_size, length, channels, and output_length. This comment also applies to the other adaptive pooling functions in this file.
For example:
n, l, c = inputs.shape -> batch_size, length, channels = inputs.shape
out_l = output_size[0] -> output_length = output_size[0]
Style Guide References

Footnotes

1. The style guide recommends using fully spelled-out names for variables and arguments to improve clarity, e.g., `attention_scores` instead of `attn_scores`. Short names are acceptable only for very common terms like `dim` or `num`. ↩
# ---------- 1D Adaptive Pooling ----------
def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
    """Adaptive Average Pooling 1D using Two-Pool Gather method."""
    if isinstance(output_size, int):
        output_size = (output_size,)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL -> NLC

    n, l, c = inputs.shape
    out_l = output_size[0]

    small_l, big_l = get_static_window_sizes(l, out_l)
    gather_l = compute_static_gather_indices(l, out_l, big_l)

    small_pool_l = lax.reduce_window(
        inputs, 0.0, lax.add, (1, small_l, 1), (1, 1, 1), "valid"
    )
    small_pool_l = small_pool_l / small_l

    big_pool_l = lax.reduce_window(
        inputs, 0.0, lax.add, (1, big_l, 1), (1, 1, 1), "valid"
    )
    big_pool_l = big_pool_l / big_l

    combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1)
    pooled_l = jnp.take(combined_l, gather_l, axis=1)

    if data_format == "channels_first":
        pooled_l = jnp.transpose(pooled_l, (0, 2, 1))  # NLC -> NCL

    return pooled_l


def adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
    """Adaptive Max Pooling 1D using Two-Pool Gather method."""
    if isinstance(output_size, int):
        output_size = (output_size,)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL -> NLC

    n, l, c = inputs.shape
    out_l = output_size[0]

    small_l, big_l = get_static_window_sizes(l, out_l)
    gather_l = compute_static_gather_indices(l, out_l, big_l)

    small_pool_l = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, small_l, 1), (1, 1, 1), "valid"
    )
    big_pool_l = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, big_l, 1), (1, 1, 1), "valid"
    )

    combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1)
    pooled_l = jnp.take(combined_l, gather_l, axis=1)

    if data_format == "channels_first":
        pooled_l = jnp.transpose(pooled_l, (0, 2, 1))  # NLC -> NCL

    return pooled_l


# ---------- 2D Adaptive Pooling ----------
def adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"):
    """Adaptive Average Pooling 2D using Two-Pool Gather method."""
    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 1))  # NCHW -> NHWC

    n, h, w, c = inputs.shape
    out_h, out_w = output_size

    small_h, big_h = get_static_window_sizes(h, out_h)
    gather_h = compute_static_gather_indices(h, out_h, big_h)

    small_w, big_w = get_static_window_sizes(w, out_w)
    gather_w = compute_static_gather_indices(w, out_w, big_w)

    small_pool_h = lax.reduce_window(
        inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
    )
    small_pool_h = small_pool_h / small_h

    big_pool_h = lax.reduce_window(
        inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
    )
    big_pool_h = big_pool_h / big_h

    combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1)
    pooled_h = jnp.take(combined_h, gather_h, axis=1)

    small_pool_w = lax.reduce_window(
        pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
    )
    small_pool_w = small_pool_w / small_w

    big_pool_w = lax.reduce_window(
        pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
    )
    big_pool_w = big_pool_w / big_w

    combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2)
    pooled_w = jnp.take(combined_w, gather_w, axis=2)

    if data_format == "channels_first":
        pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2))  # NHWC -> NCHW

    return pooled_w


def adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
    """Adaptive Max Pooling 2D using Two-Pool Gather method."""
    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 1))  # NCHW -> NHWC

    n, h, w, c = inputs.shape
    out_h, out_w = output_size

    small_h, big_h = get_static_window_sizes(h, out_h)
    gather_h = compute_static_gather_indices(h, out_h, big_h)

    small_w, big_w = get_static_window_sizes(w, out_w)
    gather_w = compute_static_gather_indices(w, out_w, big_w)

    small_pool_h = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
    )
    big_pool_h = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
    )

    combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1)
    pooled_h = jnp.take(combined_h, gather_h, axis=1)

    small_pool_w = lax.reduce_window(
        pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
    )
    big_pool_w = lax.reduce_window(
        pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
    )

    combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2)
    pooled_w = jnp.take(combined_w, gather_w, axis=2)

    if data_format == "channels_first":
        pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2))  # NHWC -> NCHW

    return pooled_w


# ---------- 3D Adaptive Pooling ----------
def adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"):
    """Adaptive Average Pooling 3D using Two-Pool Gather method."""
    if isinstance(output_size, int):
        output_size = (output_size, output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))  # NCDHW -> NDHWC

    n, d, h, w, c = inputs.shape
    out_d, out_h, out_w = output_size

    small_d, big_d = get_static_window_sizes(d, out_d)
    gather_d = compute_static_gather_indices(d, out_d, big_d)

    small_h, big_h = get_static_window_sizes(h, out_h)
    gather_h = compute_static_gather_indices(h, out_h, big_h)

    small_w, big_w = get_static_window_sizes(w, out_w)
    gather_w = compute_static_gather_indices(w, out_w, big_w)

    small_pool_d = lax.reduce_window(
        inputs, 0.0, lax.add, (1, small_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
    )
    small_pool_d = small_pool_d / small_d

    big_pool_d = lax.reduce_window(
        inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
    )
    big_pool_d = big_pool_d / big_d

    combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1)
    pooled_d = jnp.take(combined_d, gather_d, axis=1)

    small_pool_h = lax.reduce_window(
        pooled_d, 0.0, lax.add, (1, 1, small_h, 1, 1), (1, 1, 1, 1, 1), "valid"
    )
    small_pool_h = small_pool_h / small_h

    big_pool_h = lax.reduce_window(
        pooled_d, 0.0, lax.add, (1, 1, big_h, 1, 1), (1, 1, 1, 1, 1), "valid"
    )
    big_pool_h = big_pool_h / big_h

    combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2)
    pooled_h = jnp.take(combined_h, gather_h, axis=2)

    small_pool_w = lax.reduce_window(
        pooled_h, 0.0, lax.add, (1, 1, 1, small_w, 1), (1, 1, 1, 1, 1), "valid"
    )
    small_pool_w = small_pool_w / small_w

    big_pool_w = lax.reduce_window(
        pooled_h, 0.0, lax.add, (1, 1, 1, big_w, 1), (1, 1, 1, 1, 1), "valid"
    )
    big_pool_w = big_pool_w / big_w

    combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3)
    pooled_w = jnp.take(combined_w, gather_w, axis=3)

    if data_format == "channels_first":
        pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3))  # NDHWC -> NCDHW

    return pooled_w


def adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
    """Adaptive Max Pooling 3D using Two-Pool Gather method."""
    if isinstance(output_size, int):
        output_size = (output_size, output_size, output_size)

    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))  # NCDHW -> NDHWC

    n, d, h, w, c = inputs.shape
    out_d, out_h, out_w = output_size

    small_d, big_d = get_static_window_sizes(d, out_d)
    gather_d = compute_static_gather_indices(d, out_d, big_d)

    small_h, big_h = get_static_window_sizes(h, out_h)
    gather_h = compute_static_gather_indices(h, out_h, big_h)

    small_w, big_w = get_static_window_sizes(w, out_w)
    gather_w = compute_static_gather_indices(w, out_w, big_w)

    small_pool_d = lax.reduce_window(
        inputs,
        -jnp.inf,
        lax.max,
        (1, small_d, 1, 1, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )
    big_pool_d = lax.reduce_window(
        inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
    )

    combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1)
    pooled_d = jnp.take(combined_d, gather_d, axis=1)

    small_pool_h = lax.reduce_window(
        pooled_d,
        -jnp.inf,
        lax.max,
        (1, 1, small_h, 1, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )
    big_pool_h = lax.reduce_window(
        pooled_d,
        -jnp.inf,
        lax.max,
        (1, 1, big_h, 1, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2)
    pooled_h = jnp.take(combined_h, gather_h, axis=2)

    small_pool_w = lax.reduce_window(
        pooled_h,
        -jnp.inf,
        lax.max,
        (1, 1, 1, small_w, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )
    big_pool_w = lax.reduce_window(
        pooled_h,
        -jnp.inf,
        lax.max,
        (1, 1, 1, big_w, 1),
        (1, 1, 1, 1, 1),
        "valid",
    )

    combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3)
    pooled_w = jnp.take(combined_w, gather_w, axis=3)

    if data_format == "channels_first":
        pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3))  # NDHWC -> NCDHW

    return pooled_w


# ---------- Dispatcher ----------
def adaptive_avg_pool(inputs, output_size, data_format="channels_first"):
    """Dispatcher for adaptive average pooling (1D, 2D, or 3D)."""
    ndims = inputs.ndim - 2
    if ndims == 1:
        return adaptive_avg_pool1d(inputs, output_size, data_format)
    elif ndims == 2:
        return adaptive_avg_pool2d(inputs, output_size, data_format)
    elif ndims == 3:
        return adaptive_avg_pool3d(inputs, output_size, data_format)
    else:
        raise ValueError(
            "adaptive_avg_pool supports 1D, 2D, or 3D inputs only."
        )


def adaptive_max_pool(inputs, output_size, data_format="channels_first"):
    """Dispatcher for adaptive max pooling (1D, 2D, or 3D)."""
    ndims = inputs.ndim - 2
    if ndims == 1:
        return adaptive_max_pool1d(inputs, output_size, data_format)
    elif ndims == 2:
        return adaptive_max_pool2d(inputs, output_size, data_format)
    elif ndims == 3:
        return adaptive_max_pool3d(inputs, output_size, data_format)
    else:
        raise ValueError(
            "adaptive_max_pool supports 1D, 2D, or 3D inputs only."
        )
The implementations for adaptive_avg_pool{1,2,3}d and adaptive_max_pool{1,2,3}d are very similar, leading to significant code duplication. To improve maintainability, consider refactoring this code.
Here are a couple of suggestions:
- Create a helper function for each dimension (e.g., `_adaptive_pool1d`) that takes the pooling type (`'avg'` or `'max'`) as an argument; a minimal sketch of this option follows below. This would halve the number of functions.
- A more advanced refactoring would be to create a single generic n-dimensional pooling function that iterates over the spatial dimensions. This would further consolidate the logic for 1D, 2D, and 3D pooling into one place.
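A rough sketch of the first option, assuming the PR's `get_static_window_sizes` and `compute_static_gather_indices` helpers are in scope (the `mode` parameter is illustrative, not from the PR):

import jax.numpy as jnp
from jax import lax


def _adaptive_pool1d(inputs, output_size, data_format="channels_first", mode="avg"):
    if isinstance(output_size, int):
        output_size = (output_size,)
    if data_format == "channels_first":
        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL -> NLC

    length = inputs.shape[1]
    output_length = output_size[0]
    small, big = get_static_window_sizes(length, output_length)
    gather = compute_static_gather_indices(length, output_length, big)

    # The only differences between avg and max are the reduction and the
    # post-hoc normalization of the windowed sum.
    init_val, reduce_fn = (0.0, lax.add) if mode == "avg" else (-jnp.inf, lax.max)

    pools = []
    for window in (small, big):
        pooled = lax.reduce_window(
            inputs, init_val, reduce_fn, (1, window, 1), (1, 1, 1), "valid"
        )
        if mode == "avg":
            pooled = pooled / window  # turn the windowed sum into a mean
        pools.append(pooled)

    combined = jnp.concatenate(pools, axis=1)
    pooled = jnp.take(combined, gather, axis=1)

    if data_format == "channels_first":
        pooled = jnp.transpose(pooled, (0, 2, 1))  # NLC -> NCL
    return pooled


def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
    return _adaptive_pool1d(inputs, output_size, data_format, mode="avg")


def adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
    return _adaptive_pool1d(inputs, output_size, data_format, mode="max")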
raise NotImplementedError(
    "Adaptive pooling not implemented for OpenVINO. "
    "Use JAX or Torch backend."
)
The error message is missing TensorFlow as a supported backend for adaptive pooling. Please update the message to include it for accuracy.
Suggested change:

raise NotImplementedError(
    "Adaptive pooling not implemented for OpenVINO. "
    "Use JAX, TensorFlow, or Torch backend."
)
static_shape = inputs.shape.as_list()
l_static = static_shape[1]
out_l = output_size[0]
The variable names l_static and out_l are quite short. According to the Keras API design guidelines, it's preferred to use fully spelled-out names to improve readability.1 Consider using more descriptive names like static_length and output_length. This comment also applies to the other adaptive pooling functions in this file.
For example:
l_static = static_shape[1] -> static_length = static_shape[1]
out_l = output_size[0] -> output_length = output_size[0]
Style Guide References

Footnotes

1. The style guide recommends using fully spelled-out names for variables and arguments to improve clarity, e.g., `attention_scores` instead of `attn_scores`. Short names are acceptable only for very common terms like `dim` or `num`. ↩
if not isinstance(output_size, int):
    raise TypeError(
        "`output_size` must be an integer. Received output_size={} "
        "of type {}".format(output_size, type(output_size))
    )
self.output_size = output_size
self.data_format = data_format or config.image_data_format()

if self.data_format not in {"channels_first", "channels_last"}:
    raise ValueError(
        "Invalid data_format: {}. Must be either 'channels_first' "
        "or 'channels_last'.".format(self.data_format)
    )
For consistency with other new files in this PR, please use f-strings for these error messages instead of .format().
Suggested change:

if not isinstance(output_size, int):
    raise TypeError(
        f"`output_size` must be an integer. Received: output_size={output_size} "
        f"of type {type(output_size)}"
    )
self.output_size = output_size
self.data_format = data_format or config.image_data_format()
if self.data_format not in {"channels_first", "channels_last"}:
    raise ValueError(
        f"Invalid data_format: {self.data_format}. Must be either "
        "'channels_first' or 'channels_last'."
    )
    if len(output_size) != 3:
        raise ValueError(
            "`output_size` must be an integer or tuple of 3 integers. "
            "Received: {}".format(output_size)
        )
    self.output_size = tuple(output_size)
else:
    raise TypeError(
        "`output_size` must be an integer or tuple of 3 integers. "
        "Received: {} of type {}".format(output_size, type(output_size))
    )

self.data_format = data_format or config.image_data_format()

if self.data_format not in {"channels_first", "channels_last"}:
    raise ValueError(
        "Invalid data_format: {}. Must be either 'channels_first' or "
        "'channels_last'.".format(self.data_format)
    )
For consistency with other new files in this PR, please use f-strings for these error messages instead of .format().
Suggested change:

    if len(output_size) != 3:
        raise ValueError(
            "`output_size` must be an integer or tuple of 3 integers. "
            f"Received: {output_size}"
        )
    self.output_size = tuple(output_size)
else:
    raise TypeError(
        "`output_size` must be an integer or tuple of 3 integers. "
        f"Received: output_size={output_size} of type {type(output_size)}"
    )

self.data_format = data_format or config.image_data_format()

if self.data_format not in {"channels_first", "channels_last"}:
    raise ValueError(
        f"Invalid data_format: {self.data_format}. Must be either "
        "'channels_first' or 'channels_last'."
    )
| """Adaptive max pooling operation. | ||
| Applies an adaptive max pooling operation that automatically computes the | ||
| kernel size and stride to pool the input to the specified `output_size`. | ||
| This operation is useful when you want a fixed output size regardless of | ||
| input size, commonly used in models like ResNet for global feature | ||
| extraction. | ||
| Args: | ||
| inputs: Tensor of rank 4. Input tensor of shape: | ||
| - If `data_format="channels_last"`: | ||
| `(batch_size, height, width, channels)`. | ||
| - If `data_format="channels_first"`: | ||
| `(batch_size, channels, height, width)`. | ||
| output_size: Integer or tuple/list of 2 integers, specifying the target | ||
| output spatial dimensions `(output_height, output_width)`. If a | ||
| single | ||
| integer is provided, the same value is used for both dimensions. | ||
| data_format: string, either `"channels_last"` or `"channels_first"`. | ||
| Defaults to the value found in your Keras config file at | ||
| `~/.keras/keras.json`. If never set, defaults to `"channels_last"`. | ||
| Returns: | ||
| A tensor of rank 4 representing the adaptive max pooled result. | ||
| Example: | ||
| >>> x = np.random.rand(2, 64, 64, 3) | ||
| >>> y = keras.ops.adaptive_max_pool(x, output_size=(32, 32)) | ||
| >>> y.shape | ||
| (2, 32, 32, 3) | ||
| >>> # Works with any input size | ||
| >>> x = np.random.rand(2, 100, 80, 3) | ||
| >>> y = keras.ops.adaptive_max_pool(x, output_size=7) | ||
| >>> y.shape | ||
| (2, 7, 7, 3) | ||
| """ |
The docstring describes the 2D case, but this function is a dispatcher for 1D, 2D, and 3D pooling. Please update the docstring to be more general and include examples for other dimensions to avoid confusion for users.1
"""Adaptive max pooling operation for 1D, 2D, and 3D data.
This operation is useful when you want a fixed output size regardless of
input size.
Args:
inputs: Input tensor. Must be 3D, 4D, or 5D.
output_size: An integer or a tuple of integers, specifying the output
spatial dimensions.
data_format: string, either `"channels_last"` or `"channels_first"`.
Defaults to the value found in your Keras config file at
`~/.keras/keras.json`. If never set, defaults to `"channels_last"`.
Returns:
A tensor representing the adaptive max pooled result.
Example:
**2D Example**
>>> x = np.random.rand(2, 64, 64, 3)
>>> y = keras.ops.adaptive_max_pool(x, output_size=(32, 32))
>>> y.shape
(2, 32, 32, 3)
**3D Example**
>>> x = np.random.rand(2, 32, 32, 32, 3)
>>> y = keras.ops.adaptive_max_pool(x, output_size=(16, 16, 16))
>>> y.shape
(2, 16, 16, 16, 3)
"""Style Guide References
Footnotes
-
Docstrings should be comprehensive and show examples for common use cases and key features to guide the user effectively. ↩
| """Adaptive average pooling operation. | ||
| Applies an adaptive average pooling operation that automatically | ||
| computes the | ||
| kernel size and stride to pool the input to the specified `output_size`. | ||
| This operation is useful when you want a fixed output size regardless of | ||
| input size, commonly used in models like ResNet for global feature | ||
| extraction. | ||
| Args: | ||
| inputs: Tensor of rank 4. Input tensor of shape: | ||
| - If `data_format="channels_last"`: | ||
| `(batch_size, height, width, channels)`. | ||
| - If `data_format="channels_first"`: | ||
| `(batch_size, channels, height, width)`. | ||
| output_size: Integer or tuple/list of 2 integers, specifying the target | ||
| output spatial dimensions `(output_height, output_width)`. If a | ||
| single | ||
| integer is provided, the same value is used for both dimensions. | ||
| data_format: string, either `"channels_last"` or `"channels_first"`. | ||
| Defaults to the value found in your Keras config file at | ||
| `~/.keras/keras.json`. If never set, defaults to `"channels_last"`. | ||
| Returns: | ||
| A tensor of rank 4 representing the adaptive average pooled result. | ||
| Example: | ||
| >>> x = np.random.rand(2, 64, 64, 3) | ||
| >>> y = keras.ops.adaptive_avg_pool(x, output_size=(32, 32)) | ||
| >>> y.shape | ||
| (2, 32, 32, 3) | ||
| >>> # Works with any input size | ||
| >>> x = np.random.rand(2, 100, 80, 3) | ||
| >>> y = keras.ops.adaptive_avg_pool(x, output_size=7) | ||
| >>> y.shape | ||
| (2, 7, 7, 3) | ||
| """ |
The docstring describes the 2D case, but this function is a dispatcher for 1D, 2D, and 3D pooling. Please update the docstring to be more general and include examples for other dimensions to avoid confusion for users.1
"""Adaptive average pooling operation for 1D, 2D, and 3D data.
This operation is useful when you want a fixed output size regardless of
input size.
Args:
inputs: Input tensor. Must be 3D, 4D, or 5D.
output_size: An integer or a tuple of integers, specifying the output
spatial dimensions.
data_format: string, either `"channels_last"` or `"channels_first"`.
Defaults to the value found in your Keras config file at
`~/.keras/keras.json`. If never set, defaults to `"channels_last"`.
Returns:
A tensor representing the adaptive average pooled result.
Example:
**2D Example**
>>> x = np.random.rand(2, 64, 64, 3)
>>> y = keras.ops.adaptive_avg_pool(x, output_size=(32, 32))
>>> y.shape
(2, 32, 32, 3)
**3D Example**
>>> x = np.random.rand(2, 32, 32, 32, 3)
>>> y = keras.ops.adaptive_avg_pool(x, output_size=(16, 16, 16))
>>> y.shape
(2, 16, 16, 16, 3)
"""Style Guide References
Footnotes
-
Docstrings should be comprehensive and show examples for common use cases and key features to guide the user effectively. ↩
Code Review
This pull request introduces adaptive pooling (1D, 2D, and 3D) for JAX, TensorFlow, and PyTorch backends, along with the corresponding Keras layers and tests. The implementation for PyTorch leverages native operations, while for JAX and TensorFlow, a custom "Two-Pool Gather" algorithm is used. The changes are comprehensive and well-tested. My main feedback is to refactor the JAX and TensorFlow implementations to reduce significant code duplication, which will improve maintainability. I've also noted some minor issues with error messages and docstrings.
The implementations for 1D, 2D, and 3D adaptive pooling for both avg and max operations contain a significant amount of duplicated code. This makes the code harder to read and maintain.
Consider refactoring this by creating a generalized helper function. This function could handle the pooling logic for a single dimension and could be parameterized for average vs. max pooling.
For example, you could have a helper:
`_adaptive_pool_1d_single_dim(inputs, axis, output_dim, reduce_fn, init_val, normalize=False)`

Then, the 2D and 3D functions can be implemented by composing this helper function for each spatial dimension, as in the sketch below. This would greatly reduce code duplication and improve maintainability.
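A rough sketch of that helper, again assuming the PR's `get_static_window_sizes` and `compute_static_gather_indices` are in scope and a channels-last layout (this illustrates the suggestion; it is not code from the PR):

import jax.numpy as jnp
from jax import lax


def _adaptive_pool_1d_single_dim(
    inputs, axis, output_dim, reduce_fn, init_val, normalize=False
):
    input_dim = inputs.shape[axis]
    small, big = get_static_window_sizes(input_dim, output_dim)
    gather = compute_static_gather_indices(input_dim, output_dim, big)

    pools = []
    for window in (small, big):
        window_dims = [1] * inputs.ndim
        window_dims[axis] = window  # slide along the chosen axis only
        pooled = lax.reduce_window(
            inputs,
            init_val,
            reduce_fn,
            tuple(window_dims),
            (1,) * inputs.ndim,
            "valid",
        )
        if normalize:
            pooled = pooled / window  # windowed sum -> mean
        pools.append(pooled)

    combined = jnp.concatenate(pools, axis=axis)
    return jnp.take(combined, gather, axis=axis)


# 2D average pooling then composes the helper once per spatial axis (NHWC):
def adaptive_avg_pool2d(inputs, out_h, out_w):
    x = _adaptive_pool_1d_single_dim(inputs, 1, out_h, lax.add, 0.0, normalize=True)
    return _adaptive_pool_1d_single_dim(x, 2, out_w, lax.add, 0.0, normalize=True)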
raise TypeError(
    "`output_size` must be an integer or tuple of 3 integers. "
    "Received output_size={} of type {}".format(
        output_size, type(output_size)
    )
)
For consistency with other new layer files in this PR, please use an f-string for this error message.
Suggested change:

raise TypeError(
    "`output_size` must be an integer or tuple of 3 integers. "
    f"Received: output_size={output_size} of type "
    f"{type(output_size)}"
)
raise ValueError(
    "`output_size` must be an integer or tuple of 3 integers. "
    "Received: {}".format(output_size)
)
For consistency, please use an f-string for this error message. Also, consider raising a TypeError instead of a ValueError here, as the check is on the length of the output_size tuple, which relates to its structure/type in this context.
Suggested change:

raise ValueError(
    "`output_size` must be an integer or tuple of 3 integers. "
    f"Received: output_size={output_size}"
)
raise TypeError(
    "`output_size` must be an integer or tuple of 3 integers. "
    "Received: {} of type {}".format(output_size, type(output_size))
)
For consistency with other new layer files in this PR, please use an f-string for this error message.
Suggested change:

raise TypeError(
    "`output_size` must be an integer or tuple of 3 integers. "
    f"Received: output_size={output_size} of type {type(output_size)}"
)
| """Adaptive max pooling operation. | ||
| Applies an adaptive max pooling operation that automatically computes the | ||
| kernel size and stride to pool the input to the specified `output_size`. | ||
| This operation is useful when you want a fixed output size regardless of | ||
| input size, commonly used in models like ResNet for global feature | ||
| extraction. | ||
| Args: | ||
| inputs: Tensor of rank 4. Input tensor of shape: | ||
| - If `data_format="channels_last"`: | ||
| `(batch_size, height, width, channels)`. | ||
| - If `data_format="channels_first"`: | ||
| `(batch_size, channels, height, width)`. | ||
| output_size: Integer or tuple/list of 2 integers, specifying the target | ||
| output spatial dimensions `(output_height, output_width)`. If a | ||
| single | ||
| integer is provided, the same value is used for both dimensions. |
The docstring for adaptive_max_pool is not entirely accurate. It states that the input is a 'Tensor of rank 4' and output_size is for 2D inputs. However, this function supports 1D, 2D, and 3D inputs (ranks 3, 4, and 5). Please update the docstring to reflect this, for example:
Args:
    inputs: Tensor of rank 3, 4, or 5.
    output_size: Integer or tuple/list of 1, 2, or 3 integers, specifying
        the target output spatial dimensions. If a single integer is
        provided, the same value is used for all spatial dimensions.

"""Adaptive average pooling operation.

Applies an adaptive average pooling operation that automatically computes
the kernel size and stride to pool the input to the specified
`output_size`.

This operation is useful when you want a fixed output size regardless of
input size, commonly used in models like ResNet for global feature
extraction.

Args:
    inputs: Tensor of rank 4. Input tensor of shape:
        - If `data_format="channels_last"`:
            `(batch_size, height, width, channels)`.
        - If `data_format="channels_first"`:
            `(batch_size, channels, height, width)`.
    output_size: Integer or tuple/list of 2 integers, specifying the target
        output spatial dimensions `(output_height, output_width)`. If a
        single integer is provided, the same value is used for both
        dimensions.
The docstring for adaptive_avg_pool is not entirely accurate. It states that the input is a 'Tensor of rank 4' and output_size is for 2D inputs. However, this function supports 1D, 2D, and 3D inputs (ranks 3, 4, and 5). Please update the docstring to reflect this, for example:
Args:
    inputs: Tensor of rank 3, 4, or 5.
    output_size: Integer or tuple/list of 1, 2, or 3 integers, specifying
        the target output spatial dimensions. If a single integer is
        provided, the same value is used for all spatial dimensions.
fix #21813
Add adaptive pooling support across major backends
This PR implements adaptive pooling for 1D, 2D, and 3D across the JAX, TensorFlow, and PyTorch backends.
For PyTorch, native adaptive pooling ops are used.
For JAX and TensorFlow, adaptive pooling is implemented using an efficient n-dimensional two-pool gather algorithm, eliminating multiple for-loops and providing robust performance on CPU, GPU, and TPU.
All corresponding unit tests for JAX, TensorFlow, and PyTorch adaptive pooling pass successfully.
Verified in real training-model tests; both TensorFlow and PyTorch pass in GPU and CPU environments.
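As a quick usage illustration of the new ops (shapes follow the docstrings above with `channels_last` data; treat this as an untested sketch):

import numpy as np
import keras

# 2D adaptive average pooling: any input size maps to the requested grid.
x = np.random.rand(2, 100, 80, 3).astype("float32")
y = keras.ops.adaptive_avg_pool(x, output_size=7)  # shape (2, 7, 7, 3)

# 1D adaptive max pooling: rank-3 input, one target length.
seq = np.random.rand(4, 50, 8).astype("float32")
z = keras.ops.adaptive_max_pool(
    seq, output_size=(10,), data_format="channels_last"
)  # shape (4, 10, 8)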